{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# general imports\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# torch imports\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.optim.lr_scheduler as schedulers\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "from torchsummary import summary\n",
    "\n",
    "# nupic research imports\n",
    "from nupic.research.frameworks.pytorch.image_transforms import RandomNoise\n",
    "from nupic.torch.modules import KWinners\n",
    "\n",
    "PATH_TO_WHERE_DATASET_WILL_BE_SAVED = PATH = \"~/nta/datasets\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    \"\"\"Loads a dataset.\n",
    "    Returns object with a pytorch train and test loader\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config=None):\n",
    "\n",
    "        defaults = dict(\n",
    "            dataset_name='MNIST',\n",
    "            data_dir=os.path.expanduser(PATH),\n",
    "            batch_size_train=128,\n",
    "            batch_size_test=128,\n",
    "            stats_mean=None,\n",
    "            stats_std=None,\n",
    "            augment_images=False,\n",
    "            test_noise=False,\n",
    "            noise_level=0.1,\n",
    "        )\n",
    "        defaults.update(config or {})\n",
    "        self.__dict__.update(defaults)\n",
    "        \n",
    "        # recover mean and std to normalize dataset\n",
    "        if not self.stats_mean or not self.stats_std:\n",
    "            tempset = getattr(datasets, self.dataset_name)(\n",
    "                root=self.data_dir, train=True, transform=transforms.ToTensor()\n",
    "            )\n",
    "            self.stats_mean = (tempset.data.float().mean().item() / 255,)\n",
    "            self.stats_std = (tempset.data.float().std().item() / 255,)\n",
    "            del tempset\n",
    "\n",
    "        # set up transformations\n",
    "        transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize(self.stats_mean, self.stats_std),\n",
    "            ]\n",
    "        )\n",
    "        \n",
    "        # set up augment transforms for training\n",
    "        if not self.augment_images:\n",
    "            aug_transform = transform\n",
    "        else:\n",
    "            aug_transform = transforms.Compose(\n",
    "                [\n",
    "                    transforms.RandomCrop(32, padding=4),\n",
    "                    transforms.RandomHorizontalFlip(),\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Normalize(self.stats_mean, self.stats_std),\n",
    "                ]\n",
    "            )\n",
    "\n",
    "        # load train set\n",
    "        train_set = getattr(datasets, self.dataset_name)(\n",
    "            root=self.data_dir, train=True, transform=aug_transform, download=True\n",
    "        )\n",
    "        self.train_loader = DataLoader(\n",
    "            dataset=train_set, batch_size=self.batch_size_train, shuffle=True\n",
    "        )\n",
    "\n",
    "        # load test set\n",
    "        test_set = getattr(datasets, self.dataset_name)(\n",
    "            root=self.data_dir, train=False, transform=transform, download=True\n",
    "        )\n",
    "        self.test_loader = DataLoader(\n",
    "            dataset=test_set, batch_size=self.batch_size_test, shuffle=False\n",
    "        )\n",
    "\n",
    "        # noise dataset\n",
    "        noise = self.noise_level\n",
    "        noise_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize(self.stats_mean, self.stats_std),\n",
    "                RandomNoise(\n",
    "                    noise, high_value=0.5 + 2 * 0.20, low_value=0.5 - 2 * 0.2\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        noise_set = getattr(datasets, self.dataset_name)(\n",
    "            root=self.data_dir, train=False, transform=noise_transform\n",
    "        )\n",
    "        self.noise_loader = DataLoader(\n",
    "            dataset=noise_set, batch_size=self.batch_size_test, shuffle=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = Dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Build Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Simple 3 hidden layers + output MLP\"\"\"\n",
    "\n",
    "    def __init__(self, config=None):\n",
    "        super(MLP, self).__init__()\n",
    "\n",
    "        defaults = dict(\n",
    "            device='cpu',\n",
    "            input_size=784,\n",
    "            num_classes=10,\n",
    "            hidden_sizes=[400, 400, 400],\n",
    "            batch_norm=False,\n",
    "            dropout=False,\n",
    "            use_kwinners=False\n",
    "        )\n",
    "        defaults.update(config or {})\n",
    "        self.__dict__.update(defaults)\n",
    "        self.device = torch.device(self.device)\n",
    "        \n",
    "        # decide which actiovation function to use\n",
    "        if self.use_kwinners: self.activation_func = self._kwinners\n",
    "        else: self.activation_func = lambda _: nn.ReLU()\n",
    "\n",
    "        # create the layers\n",
    "        layers = [\n",
    "            *self._linear_block(self.input_size, self.hidden_sizes[0]),\n",
    "            *self._linear_block(self.hidden_sizes[0],  self.hidden_sizes[1]),\n",
    "            *self._linear_block(self.hidden_sizes[1], self.hidden_sizes[2]),\n",
    "            nn.Linear(self.hidden_sizes[2], self.num_classes),\n",
    "        ]\n",
    "        self.classifier = nn.Sequential(*layers)\n",
    "\n",
    "    def _linear_block(self, a, b):\n",
    "        block = [nn.Linear(a, b), self.activation_func(b)]\n",
    "        if self.batch_norm: block.append(nn.BatchNorm1d(b))\n",
    "        if self.dropout: block.append(nn.Dropout(p=self.dropout))\n",
    "        return block\n",
    "\n",
    "    def _kwinners(self, num_units):\n",
    "        return KWinners(\n",
    "            n=num_units,\n",
    "            percent_on=0.3,\n",
    "            boost_strength=1.4,\n",
    "            boost_strength_factor=0.7,\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        # need to flatten input before forward pass\n",
    "        return self.classifier(x.view(-1, self.input_size))\n",
    "    \n",
    "    def alternative_forward(self, x):\n",
    "        \"\"\"Replace forward function by this to visualize activations\"\"\"\n",
    "        # need to flatten before forward pass\n",
    "        x = x.view(-1, self.input_size)\n",
    "        for layer in self.classifier:\n",
    "            # apply the transformation\n",
    "            x = layer(x)\n",
    "            # do something with the activation\n",
    "            print(torch.mean(x).item())\n",
    "        \n",
    "        return x\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "network = MLP()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Linear-1                  [-1, 400]         314,000\n",
      "              ReLU-2                  [-1, 400]               0\n",
      "            Linear-3                  [-1, 400]         160,400\n",
      "              ReLU-4                  [-1, 400]               0\n",
      "            Linear-5                  [-1, 400]         160,400\n",
      "              ReLU-6                  [-1, 400]               0\n",
      "            Linear-7                   [-1, 10]           4,010\n",
      "================================================================\n",
      "Total params: 638,810\n",
      "Trainable params: 638,810\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.02\n",
      "Params size (MB): 2.44\n",
      "Estimated Total Size (MB): 2.46\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "summary(network, input_size=(1,28,28))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class BaseModel:\n",
    "    \"\"\"Base model, with training loops and logging functions.\"\"\"\n",
    "\n",
    "    def __init__(self, network, dataset, config=None):\n",
    "        defaults = dict(\n",
    "            learning_rate=0.1,\n",
    "            momentum=0.9,\n",
    "            device=\"cpu\",\n",
    "            lr_scheduler=False,\n",
    "            sparsity=0,\n",
    "            weight_decay=1e-4,\n",
    "            test_noise=False\n",
    "        )\n",
    "        defaults.update(config or {})\n",
    "        self.__dict__.update(defaults)\n",
    "\n",
    "        self.device = torch.device(self.device)\n",
    "        self.network = network.to(self.device)\n",
    "        self.dataset = dataset\n",
    "        self.setup()\n",
    "    \n",
    "        # apply sparsity\n",
    "        if self.sparsity:\n",
    "            self.sparse_layers = []\n",
    "            for layer in self.network.modules():\n",
    "                if isinstance(layer, nn.Linear):\n",
    "                    shape = layer.weight.shape\n",
    "                    mask = (torch.rand(shape) > self.sparsity).float().to(self.device)\n",
    "                    layer.weight.data *= mask\n",
    "                    self.sparse_layers.append((layer, mask))\n",
    "\n",
    "    def setup(self):\n",
    "        self.optimizer = optim.SGD(\n",
    "            self.network.parameters(),\n",
    "            lr=self.learning_rate,\n",
    "            momentum=self.momentum,\n",
    "            weight_decay=self.weight_decay,\n",
    "        )\n",
    "\n",
    "        # add a learning rate scheduler\n",
    "        if self.lr_scheduler:\n",
    "            milestones = [int(self.num_epochs/2), int(num_epochs*3/4)]\n",
    "            self.lr_scheduler = schedulers.MultiStepLR(\n",
    "                self.optimizer, milestones=milestones, gamma=0.1\n",
    "            )\n",
    "\n",
    "        # init loss function\n",
    "        self.loss_func = nn.CrossEntropyLoss()\n",
    "        self.epoch = 0\n",
    "    \n",
    "    def train(self, num_epochs, test_noise=False):\n",
    "        for i in range(num_epochs):\n",
    "            log = self.run_epoch(test_noise)\n",
    "            # print acc\n",
    "            if test_noise:\n",
    "                print(\"Train acc: {:.4f}, Val acc: {:.4f}, Noise acc: {:.4f}\".format(\n",
    "                    log['train_acc'], log['val_acc'], log['test_acc']))\n",
    "            else:\n",
    "                print(\"Train acc: {:.4f}, Val acc: {:.4f}\".format(\n",
    "                    log['train_acc'], log['val_acc']))\n",
    "        \n",
    "    def run_epoch(self, test_noise):\n",
    "\n",
    "        log = {}        \n",
    "        self.epoch += 1\n",
    "        log['current_epoch'] = self.epoch\n",
    "        \n",
    "        # train\n",
    "        self.network.train()\n",
    "        log['train_loss'], log['train_acc'] = \\\n",
    "            self._run_one_pass(self.dataset.train_loader, train=True)\n",
    "        # validate\n",
    "        self.network.eval()\n",
    "        log['val_loss'], log['val_acc'] = \\\n",
    "            self._run_one_pass(self.dataset.test_loader, train=False)\n",
    "        # additional validation for noise\n",
    "        if test_noise:\n",
    "            log['test_loss'], log['test_acc'] = \\\n",
    "                self._run_one_pass(self.dataset.noise_loader, train=False, noise=True)\n",
    "\n",
    "        # any updates post training, e.g. scheduler\n",
    "        self._post_epoch_updates(dataset)\n",
    "        \n",
    "        return log\n",
    "\n",
    "    def _post_epoch_updates(self, dataset=None):\n",
    "        # update learning rate\n",
    "        if self.lr_scheduler:\n",
    "            self.lr_scheduler.step()\n",
    "\n",
    "    def _run_one_pass(self, loader, train=True, noise=False):\n",
    "\n",
    "        epoch_loss = 0\n",
    "        correct = 0\n",
    "        for inputs, targets in loader:\n",
    "            # setup for training\n",
    "            inputs = inputs.to(self.device)\n",
    "            targets = targets.to(self.device)\n",
    "            self.optimizer.zero_grad()\n",
    "            # training loop\n",
    "            with torch.set_grad_enabled(train):\n",
    "                # forward + backward + optimize\n",
    "                outputs = self.network(inputs)\n",
    "                _, preds = torch.max(outputs, 1)\n",
    "                correct += torch.sum(targets == preds).item()\n",
    "                loss = self.loss_func(outputs, targets)\n",
    "                if train:\n",
    "                    loss.backward()\n",
    "                    self.optimizer.step()\n",
    "                    # if sparse, apply the mask to weights after optimization\n",
    "                    if self.sparsity:\n",
    "                        for layer, mask in self.sparse_layers:\n",
    "                            layer.weight.data *= mask\n",
    "\n",
    "            # keep track of loss\n",
    "            epoch_loss += loss.item() * inputs.size(0)\n",
    "\n",
    "        # store loss and acc at each pass\n",
    "        loss = epoch_loss / len(loader.dataset)\n",
    "        acc = correct / len(loader.dataset)\n",
    "\n",
    "        return loss, acc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BaseModel(network, dataset, dict(k_winners=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Train regular dense network with ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train acc: 0.9137, Val acc: 0.9685, Noise acc: 0.9643\n",
      "Train acc: 0.9705, Val acc: 0.9721, Noise acc: 0.9678\n",
      "Train acc: 0.9780, Val acc: 0.9727, Noise acc: 0.9674\n",
      "Train acc: 0.9830, Val acc: 0.9736, Noise acc: 0.9708\n",
      "Train acc: 0.9856, Val acc: 0.9733, Noise acc: 0.9694\n"
     ]
    }
   ],
   "source": [
    "model.train(num_epochs=5, test_noise=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. KWinnners instead of ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train acc: 0.9096, Val acc: 0.9626, Noise acc: 0.9600\n",
      "Train acc: 0.9673, Val acc: 0.9682, Noise acc: 0.9658\n",
      "Train acc: 0.9780, Val acc: 0.9753, Noise acc: 0.9724\n",
      "Train acc: 0.9820, Val acc: 0.9752, Noise acc: 0.9729\n",
      "Train acc: 0.9859, Val acc: 0.9792, Noise acc: 0.9774\n"
     ]
    }
   ],
   "source": [
    "# rebuild the network with KWinners as the activation function\n",
    "network = MLP(dict(kwinners=True))\n",
    "# build model\n",
    "model = BaseModel(network, dataset, dict(k_winners=True))\n",
    "# run model\n",
    "model.train(num_epochs=5, test_noise=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. KWinners + Sparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train acc: 0.7464, Val acc: 0.9506, Noise acc: 0.9453\n",
      "Train acc: 0.9638, Val acc: 0.9687, Noise acc: 0.9620\n",
      "Train acc: 0.9754, Val acc: 0.9745, Noise acc: 0.9669\n",
      "Train acc: 0.9819, Val acc: 0.9771, Noise acc: 0.9720\n",
      "Train acc: 0.9858, Val acc: 0.9748, Noise acc: 0.9643\n"
     ]
    }
   ],
   "source": [
    "# rebuild the network with KWinners as the activation function\n",
    "network = MLP(dict(kwinners=True))\n",
    "# build model\n",
    "model = BaseModel(network, dataset, \n",
    "                  dict(k_winners=True, sparsity=0.8))\n",
    "# run model\n",
    "model.train(num_epochs=5, test_noise=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Inspecting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Inspect weights after training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.00037051260005682707 0.031048087403178215\n",
      "0.00023493298795074224 0.029837530106306076\n",
      "0.0005359541974030435 0.023514844477176666\n",
      "0.0012203083606436849 0.11726776510477066\n"
     ]
    }
   ],
   "source": [
    "for layer in network.modules():\n",
    "    if isinstance(layer, nn.Linear):\n",
    "        print(torch.mean(layer.weight).item(), torch.std(layer.weight).item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Inspect activations during the forward pass (can use it to inspect during training as well)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.18638744950294495\n",
      "0.17727136611938477\n",
      "0.017351068556308746\n",
      "0.10145770758390427\n",
      "0.05837811529636383\n",
      "0.11567458510398865\n",
      "-1.3849645853042603\n"
     ]
    }
   ],
   "source": [
    "network.forward = network.alternative_forward\n",
    "network(torch.rand(1,1,28,28));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
